{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDPA operation using cudnn FE and serialization\n",
    "This notebook shows how to run a scaled dot-product attention (SDPA) operation using the cudnn frontend (FE), and how to serialize and deserialize the resulting graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/02_sdpa_graph_serialization.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the cudnn python interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n",
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### General Setup\n",
    "In this example we call cudnn through torch. In general, any dlpack tensor should work.\n",
    "The cudnn handle is a per-device handle used to initialize the cudnn context.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "from enum import Enum\n",
    "\n",
    "# Create the cudnn handle once; it is passed to the graph constructor in the graph build helper.\n",
    "handle = cudnn.create_handle()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem sizes.\n",
    "b = 2        # batch size\n",
    "h = 6        # number of query heads\n",
    "s_q = 1024   # query sequence length\n",
    "s_kv = 1024  # key/value sequence length\n",
    "d = 64       # embedding dimension per head\n",
    "\n",
    "# Logical tensor dims are (batch, head, sequence, embedding).\n",
    "shape_q = shape_o = (b, h, s_q, d)\n",
    "shape_k = shape_v = (b, h, s_kv, d)\n",
    "\n",
    "# The head stride (d) is smaller than the sequence stride (h * d), so in memory\n",
    "# all heads of one sequence position are stored next to each other.\n",
    "stride_q = stride_o = (s_q * h * d, d, h * d, 1)\n",
    "stride_k = stride_v = (s_kv * h * d, d, h * d, 1)\n",
    "\n",
    "attn_scale = 0.125  # equals 1 / sqrt(d) for d = 64\n",
    "\n",
    "# Inputs are random, outputs are uninitialized; each flat buffer is viewed\n",
    "# through the shape/stride pair defined above.\n",
    "q_gpu = torch.randn(\n",
    "    b * h * s_q * d, dtype=torch.bfloat16, device=\"cuda\"\n",
    ").as_strided(shape_q, stride_q)\n",
    "k_gpu = torch.randn(\n",
    "    b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\"\n",
    ").as_strided(shape_k, stride_k)\n",
    "v_gpu = torch.randn(\n",
    "    b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\"\n",
    ").as_strided(shape_v, stride_v)\n",
    "o_gpu = torch.empty(\n",
    "    b * h * s_q * d, dtype=torch.bfloat16, device=\"cuda\"\n",
    ").as_strided(shape_o, stride_o)\n",
    "stats_gpu = torch.empty((b, h, s_q, 1), dtype=torch.float32, device=\"cuda\")\n",
    "\n",
    "class UIDs(Enum):\n",
    "    \"\"\"Ids set on the graph tensors and used as keys of the variant pack.\"\"\"\n",
    "    Q_UID     = 0\n",
    "    K_UID     = 1\n",
    "    V_UID     = 2\n",
    "    O_UID     = 3\n",
    "    STATS_UID = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Graph build helper\n",
    "This helper is called by the `check_support` and `serialize` functions to build the sdpa graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_and_validate_graph_helper():\n",
    "    \"\"\"Build and validate the cudnn graph for causal sdpa over q_gpu, k_gpu and v_gpu.\n",
    "\n",
    "    The q, k and v graph tensors are created with tensor_like from the torch\n",
    "    tensors. Because is_inference is False the sdpa node also yields a stats\n",
    "    tensor; both it and the attention output are marked as graph outputs.\n",
    "    Every tensor gets its uid from UIDs so it can be bound by uid at execution.\n",
    "\n",
    "    Returns:\n",
    "        The validated cudnn.pygraph. The operation graph and the execution\n",
    "        plans are not built here.\n",
    "    \"\"\"\n",
    "    graph = cudnn.pygraph(\n",
    "        # Must match the dtype of the q/k/v/o buffers (torch.bfloat16). The output\n",
    "        # tensor o sets no data type of its own, so it presumably falls back to\n",
    "        # io_data_type; HALF here would not match the bfloat16 o_gpu buffer.\n",
    "        io_data_type=cudnn.data_type.BFLOAT16,\n",
    "        intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "        compute_data_type=cudnn.data_type.FLOAT,\n",
    "        handle=handle,\n",
    "    )\n",
    "\n",
    "    q = graph.tensor_like(q_gpu)\n",
    "    k = graph.tensor_like(k_gpu)\n",
    "    v = graph.tensor_like(v_gpu)\n",
    "\n",
    "    o, stats = graph.sdpa(\n",
    "        name=\"sdpa\",\n",
    "        q=q, k=k, v=v,\n",
    "        is_inference=False,\n",
    "        attn_scale=attn_scale,\n",
    "        use_causal_mask=True,\n",
    "    )\n",
    "\n",
    "    # The output layout is given explicitly so that it matches the o_gpu view.\n",
    "    o.set_output(True).set_dim(shape_o).set_stride(stride_o)\n",
    "    # stats_gpu is a float32 buffer, so the stats tensor is declared as FLOAT.\n",
    "    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n",
    "\n",
    "    q.set_uid(UIDs.Q_UID.value)\n",
    "    k.set_uid(UIDs.K_UID.value)\n",
    "    v.set_uid(UIDs.V_UID.value)\n",
    "    o.set_uid(UIDs.O_UID.value)\n",
    "    stats.set_uid(UIDs.STATS_UID.value)\n",
    "\n",
    "    graph.validate()\n",
    "\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check support "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_support():\n",
    "    \"\"\"Build the sdpa graph and run the cudnn support check on it.\n",
    "\n",
    "    Execution plans are created with heuristic modes A and FALLBACK. Nothing is\n",
    "    returned and the graph is discarded afterwards.\n",
    "    \"\"\"\n",
    "    heuristics = [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]\n",
    "\n",
    "    sdpa_graph = build_and_validate_graph_helper()\n",
    "    sdpa_graph.build_operation_graph()\n",
    "    sdpa_graph.create_execution_plans(heuristics)\n",
    "    sdpa_graph.check_support()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Serialization function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def serialize():\n",
    "    \"\"\"Build the sdpa graph through to its plans and return it in serialized form.\n",
    "\n",
    "    Steps, in order: build the operation graph, create execution plans with\n",
    "    heuristic modes A and FALLBACK, check support, build the plans, serialize.\n",
    "\n",
    "    Returns:\n",
    "        The value produced by the graph's serialize() call.\n",
    "    \"\"\"\n",
    "    sdpa_graph = build_and_validate_graph_helper()\n",
    "    sdpa_graph.build_operation_graph()\n",
    "\n",
    "    heuristics = [cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]\n",
    "    sdpa_graph.create_execution_plans(heuristics)\n",
    "    sdpa_graph.check_support()\n",
    "    sdpa_graph.build_plans()\n",
    "\n",
    "    payload = sdpa_graph.serialize()\n",
    "    return payload"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### De-serialization function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deserialize(payload):\n",
    "    \"\"\"Create an empty cudnn graph and populate it from a serialized payload.\n",
    "\n",
    "    Args:\n",
    "        payload: the value returned by serialize().\n",
    "\n",
    "    Returns:\n",
    "        The cudnn.pygraph after deserialize(payload) has been called on it.\n",
    "    \"\"\"\n",
    "    restored_graph = cudnn.pygraph()\n",
    "    restored_graph.deserialize(payload)\n",
    "    return restored_graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running the execution plan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the support check first, on a graph that is then discarded.\n",
    "check_support()\n",
    "\n",
    "# Round-trip a fully built graph through its serialized form.\n",
    "data = serialize()\n",
    "deserialized_graph = deserialize(data)\n",
    "\n",
    "# uint8 scratch buffer with as many elements as the graph's reported workspace size.\n",
    "workspace = torch.empty(\n",
    "    deserialized_graph.get_workspace_size(), dtype=torch.uint8, device=\"cuda\"\n",
    ")\n",
    "\n",
    "# Bind each device buffer to the uid that was set on the matching graph tensor.\n",
    "variant_pack = {\n",
    "    uid.value: buffer\n",
    "    for uid, buffer in (\n",
    "        (UIDs.Q_UID, q_gpu),\n",
    "        (UIDs.K_UID, k_gpu),\n",
    "        (UIDs.V_UID, v_gpu),\n",
    "        (UIDs.O_UID, o_gpu),\n",
    "        (UIDs.STATS_UID, stats_gpu),\n",
    "    )\n",
    "}\n",
    "\n",
    "deserialized_graph.execute(variant_pack, workspace)\n",
    "\n",
    "# Wait for the GPU work to finish.\n",
    "torch.cuda.synchronize()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "build_thunder",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
